from psynet.trial.non_adaptive import (
    NonAdaptiveTrialMaker,
    NonAdaptiveTrial,
    StimulusSet,
    StimulusSpec,
    StimulusVersionSpec
)

import csv
from os import listdir
from os.path import isfile, join, splitext
import json
import numpy as np
import random

# random.seed(27)
# SAMPLE_BATCH_SIZE = 100
STIMULUS_PATH = "validation_data"
BLOCKS = ["pleasant"]

# Stimulus files to load: regular (non-directory) files in STIMULUS_PATH.
# Hidden files (".DS_Store", macOS "._*" metadata, ".gitkeep", ...) are skipped
# because they are not stimulus CSVs and do not fit the
# "<phase>_<source>_<frame>" name that tag_to_vers parses.
# Sorted because os.listdir returns entries in arbitrary order; sorting keeps
# the stimulus order identical across machines and runs.
TAGS = sorted(
    f
    for f in listdir(STIMULUS_PATH)
    if isfile(join(STIMULUS_PATH, f)) and not f.startswith(".")
)


def tag_to_vers(tag, stimulus_path=None):
    """Load the stimulus versions stored in the CSV file ``tag``.

    The file name (minus extension) must have the form
    ``<phase>_<source>_<frame>``; that stem is also used as the prefix of
    every generated id. The first CSV row is a header and is skipped, as are
    blank rows.

    * For every file except ``II_mode_*``, each data row holds two
      JSON-encoded intervals in its first two columns and becomes its own
      single-element version: ``[{"intervals": [...], "id": "<stem>_<n>"}]``,
      where ``n`` is the 1-based data-row number.
    * ``II_mode_*`` files have the columns ``v1, v2, variant, group``. Rows
      sharing the same integer ``group`` are collected into one version (a
      list of ``{"intervals", "variant", "id"}`` dicts). Each version's order
      is shuffled with ``np.random.shuffle`` to randomize the order of the
      NAFC options. Versions are returned in ascending group order, and a
      group may contain any number of variants.

    Args:
        tag: File name inside the stimulus directory.
        stimulus_path: Directory containing ``tag``; defaults to the
            module-level ``STIMULUS_PATH``.

    Returns:
        A list of versions, each of which is a list of stimulus dicts.

    Raises:
        ValueError: if the file name does not split into exactly three
            ``_``-separated parts, or a ``II_mode`` row does not have exactly
            four columns.
    """
    if stimulus_path is None:
        stimulus_path = STIMULUS_PATH

    specs = splitext(tag)[0]  # file stem, reused as the id prefix
    parts = specs.split("_")
    if len(parts) != 3:
        raise ValueError(
            "stimulus file name must look like "
            f"'<phase>_<source>_<frame>.<ext>', got {tag!r}"
        )
    phase, source, _frame = parts

    # Read the file once; csv.reader yields [] for blank lines, so drop those.
    with open(join(stimulus_path, tag), newline="", encoding="utf-8") as csvfile:
        reader = csv.reader(csvfile, delimiter=",")
        next(reader, None)  # skip the header row (no-op on an empty file)
        rows = [row for row in reader if row]

    if not (phase == "II" and source == "mode"):
        stim_list = [
            [{
                "intervals": [json.loads(row[0]), json.loads(row[1])],
                "id": f"{specs}_{index}",
            }]
            for index, row in enumerate(rows, start=1)
        ]
        # if (source=="sample"):
        #     stim_list = random.choices(stim_list, k=SAMPLE_BATCH_SIZE)
        return stim_list

    # Group variants by their "group" column. The number of groups and the
    # number of variants per group are taken from the data, not assumed.
    groups = {}
    for index, row in enumerate(rows, start=1):
        v1, v2, variant, group = row
        groups.setdefault(int(group), []).append({
            "intervals": [json.loads(v1), json.loads(v2)],
            "variant": variant,
            "id": f"{specs}_{index}",
        })

    stim_list = []
    for group in sorted(groups):
        variants = groups[group]
        np.random.shuffle(variants)  # shuffle order of variants (to randomize NAFC)
        stim_list.append(variants)
    return stim_list

def _stimulus_spec(tag, block):
    """Build the experiment-phase StimulusSpec for one stimulus file in one block.

    Each version returned by ``tag_to_vers(tag)`` is wrapped in a
    ``StimulusVersionSpec`` whose definition is ``{"version": <version>}``.
    """
    versions = tag_to_vers(tag)
    return StimulusSpec(
        definition={"tag": tag},
        version_specs=[
            StimulusVersionSpec(definition={"version": v}) for v in versions
        ],
        phase="experiment",
        block=block,
    )


# One spec per (stimulus file, block) pair, iterating files in the outer loop.
stimulus_set = StimulusSet(
    [_stimulus_spec(tag, block) for tag in TAGS for block in BLOCKS]
)